Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Epsilon change in normalise for stability #2421

Merged
merged 3 commits into from
Nov 5, 2024

Conversation

billera
Copy link
Contributor

@billera billera commented Apr 7, 2024

Normalise allows for an optional epsilon term aimed towards improving numerical stability. Previously the epsilon was added after computing the standard deviation of the input. The standard deviation computation involves a square root, leading to NaN's in gradients dependent on normalise when the variance is very low, and for instance LayerNorms applied to low variance inputs will result in NaN gradients. By first computing the variance and taking the square root after adding epsilon^2 (squaring to preserve scale), we prevent NaN's in gradients at low variance. See the following example with LayerNorm in the current patch.

using Flux 
using Zygote 

ln = LayerNorm(256; eps = 1f-3)
for i in 1:10 
    x = ones(Float32, 256) .+ randn(Float32, 256) .* 10f0^(-i)
    l, gs = Zygote.withjacobian(ln, x)
    @show maximum(gs[1])
end


>>> maximum(gs[1]) = 9.44178f0
>>> maximum(gs[1]) = 95.85736f0
>>> maximum(gs[1]) = 477.4946f0
>>> maximum(gs[1]) = 910.05457f0
>>> maximum(gs[1]) = 985.8402f0
>>> maximum(gs[1]) = 995.0282f0
>>> maximum(gs[1]) = 995.9835f0
>>> maximum(gs[1]) = NaN32
>>> maximum(gs[1]) = NaN32
>>> maximum(gs[1]) = NaN32

We observe that while the gradients are fixed at low variance due to the epsilon addition in the denominator, this does prevent NaN's, due to the non-padded square root in the std computation. But, when using the updated normalise, these NaN's dissapear,

>>> maximum(gs[1]) = 9.531697f0
>>> maximum(gs[1]) = 105.468056f0
>>> maximum(gs[1]) = 674.7051f0
>>> maximum(gs[1]) = 991.67163f0
>>> maximum(gs[1]) = 996.03973f0
>>> maximum(gs[1]) = 996.09314f0
>>> maximum(gs[1]) = 996.0937f0
>>> maximum(gs[1]) = 996.0937f0
>>> maximum(gs[1]) = 996.0937f0
>>> maximum(gs[1]) = 996.0937f0

and remain fixed to the implicitly capped value. A simple test verifying this computation's equivalence with the previous one (modulo the differences at very low standard deviations) could be added if desired.

@billera billera closed this Apr 7, 2024
@billera billera reopened this Apr 7, 2024
Copy link

codecov bot commented Apr 7, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 60.45%. Comparing base (e1989b5) to head (fbc1186).
Report is 2 commits behind head on master.

Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2421       +/-   ##
===========================================
+ Coverage   33.50%   60.45%   +26.94%     
===========================================
  Files          31       31               
  Lines        1910     1942       +32     
===========================================
+ Hits          640     1174      +534     
+ Misses       1270      768      -502     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@CarloLucibello
Copy link
Member

I agree with this change and pytorch does the same thing. It should considered a breaking change thous, so let's wait for when we are near v01.5 before merging.

@CarloLucibello CarloLucibello added this to the v0.15 milestone Apr 7, 2024
Co-authored-by: Carlo Lucibello <[email protected]>
@mcabbott
Copy link
Member

mcabbott commented Apr 9, 2024

Can this have a test with input which triggers the NaN behaviour before?

Ideally testing not just the function, but also LayerNorm, maybe BatchNorm, anything which uses this internally. Then if the implementation of these layers finally gets replaced, it will be harder to lose the change.

@ToucheSir
Copy link
Member

Putting a backlink to #2096 because this work should close that.

@CarloLucibello CarloLucibello merged commit 91f2d47 into FluxML:master Nov 5, 2024
5 of 9 checks passed
Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs news entry?

@inline function normalise(x::AbstractArray; dims=ndims(x), eps=ofeltype(x, 1e-5))
@inline function normalise(x::AbstractArray; dims=ndims(x), eps=1f-5)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this now assume Float32? Elsewhere we try to allow for Float16 too.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants